import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets
from transformers import AutoImageProcessor, AutoModel


# Pooling strategy used to reduce DINOv2 token embeddings to a single
# vector per image.
POOLING_TYPE = "cls_mean"  
# Options:
# "cls"        CLS token only (768-dim)
# "mean"       Mean of the patch tokens (768-dim)
# "cls_mean"   CLS token concatenated with the patch mean (1536-dim)

# Prefer GPU when available; all model inputs are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load the pretrained DINOv2 backbone and its matching image processor.
# NOTE: this runs at import time and may download weights on first use.
model_name = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()  # inference only; disables dropout/batch-norm updates

class ModeAwareImageFolder(datasets.ImageFolder):
    """ImageFolder variant that records each image's original PIL mode.

    Unlike the base class, ``__getitem__`` returns a 4-tuple
    ``(image, label, original_mode, path)`` so downstream code can keep
    track of the source color mode (e.g. "L", "RGBA") and file path even
    though the image itself is normalized to RGB for the model.
    """

    def __getitem__(self, index):
        """Return (PIL RGB image, class label, original mode, file path)."""
        path, label = self.samples[index]

        # Open through an explicit file handle and force a full decode so
        # the descriptor is released immediately. A bare Image.open(path)
        # is lazy and keeps the file open until GC, which can exhaust file
        # descriptors when iterating large datasets.
        with open(path, "rb") as f:
            img = Image.open(f)
            img.load()

        original_mode = img.mode

        # Normalize to RGB for the model while preserving the recorded mode.
        if img.mode != "RGB":
            img = img.convert("RGB")

        return img, label, original_mode, path


def pil_collate_fn(batch):
    """Collate (image, label, mode, path) samples without stacking images.

    PIL images cannot be stacked into a tensor before preprocessing, so
    images, modes, and paths are returned as plain lists; only the labels
    are converted to a tensor.
    """
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch])
    modes = [sample[2] for sample in batch]
    paths = [sample[3] for sample in batch]
    return images, labels, modes, paths


def _pool_tokens(tokens):
    """Reduce transformer tokens to one embedding per image via POOLING_TYPE.

    Args:
        tokens: last_hidden_state of shape (batch, 1 + num_patches, dim),
            where index 0 along the token axis is the CLS token.

    Returns:
        Tensor of shape (batch, dim) for "cls"/"mean", or (batch, 2*dim)
        for "cls_mean".

    Raises:
        ValueError: if POOLING_TYPE is not one of the supported options.
    """
    if POOLING_TYPE == "cls":
        return tokens[:, 0, :]
    if POOLING_TYPE == "mean":
        return tokens[:, 1:, :].mean(dim=1)
    if POOLING_TYPE == "cls_mean":
        cls_token = tokens[:, 0, :]
        patch_mean = tokens[:, 1:, :].mean(dim=1)
        return torch.cat([cls_token, patch_mean], dim=1)
    raise ValueError(f"Invalid pooling type: {POOLING_TYPE!r}")


def extract_features(dataloader, split_name="train"):
    """Run the DINOv2 model over a dataloader and save pooled features.

    Iterates the dataloader (batches produced by pil_collate_fn), pools
    the token embeddings per POOLING_TYPE, and writes a compressed .npz
    to features/{split_name}_dinov2_{POOLING_TYPE}.npz containing
    'features', 'labels', 'modes', and 'paths' arrays.

    Args:
        dataloader: yields (PIL images, label tensor, modes, paths).
        split_name: tag used in the progress bar and the output filename.

    Raises:
        ValueError: if POOLING_TYPE is invalid (raised on the first batch).
    """
    all_features = []
    all_labels = []
    all_modes = []
    all_paths = []

    with torch.no_grad():
        for images, labels, modes, paths in tqdm(dataloader, desc=split_name):
            # Preprocess raw PIL images into model-ready pixel tensors.
            inputs = processor(images=images, return_tensors="pt").to(device)
            outputs = model(**inputs)

            embeddings = _pool_tokens(outputs.last_hidden_state)

            # Move to CPU/NumPy immediately so GPU memory is not held
            # across the whole epoch.
            all_features.append(embeddings.cpu().numpy())
            all_labels.append(labels.numpy())
            all_modes.extend(modes)
            all_paths.extend(paths)

    features = np.concatenate(all_features, axis=0)
    labels = np.concatenate(all_labels, axis=0)

    os.makedirs("features", exist_ok=True)

    npz_path = os.path.join(
        "features",
        f"{split_name}_dinov2_{POOLING_TYPE}.npz"
    )

    np.savez_compressed(
        npz_path,
        features=features,
        labels=labels,
        modes=np.array(all_modes),
        paths=np.array(all_paths)
    )

    print(f"\nSaved features to: {npz_path}")
    print("Feature shape:", features.shape)

def main():
    """Prompt for the train/test roots and extract features for each split.

    Both datasets are constructed up front (so path errors surface before
    any extraction work), then features are extracted split by split.
    """
    train_root = input("Path to TRAIN folder: ").strip()
    val_root = input("Path to TEST folder: ").strip()

    loaders = {
        split: DataLoader(
            ModeAwareImageFolder(root),
            batch_size=32,
            shuffle=False,  # keep sample order aligned with saved paths
            collate_fn=pil_collate_fn,
        )
        for split, root in (("train", train_root), ("val", val_root))
    }

    for split, loader in loaders.items():
        extract_features(loader, split)


if __name__ == "__main__":
    main()